{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!git clone https://github.com/mymusise/ChatGLM-Tuning.git\n",
    "%cd ChatGLM-Tuning\n",
    "# `%pip` (unlike `!pip`) installs into the environment of the running kernel.\n",
    "%pip install -r requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert data/alpaca_data.json into data/alpaca_data.jsonl, the input of the tokenization step below.\n",
    "!python cover_alpaca2jsonl.py --data_path data/alpaca_data.json --save_path data/alpaca_data.jsonl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenize data/alpaca_data.jsonl (with --max_seq_length 128) and save the result under data/alpaca.\n",
    "!python tokenize_dataset_rows.py --jsonl_path data/alpaca_data.jsonl --save_path data/alpaca --max_seq_length 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Put the cloned repository (the working directory after `%cd ChatGLM-Tuning`) on the\n",
    "# import path so its modules (e.g. `cover_alpaca2jsonl`) can be imported; \"../\" would point\n",
    "# one level above the repository. The membership check keeps re-runs of this cell idempotent.\n",
    "repo_root = os.path.abspath(\".\")\n",
    "if repo_root not in sys.path:\n",
    "    sys.path.insert(0, repo_root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mymusise/pro/stable-diffusion-webui/venv/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n",
      "Overriding torch_dtype=None with `torch_dtype=torch.float16` due to requirements of `bitsandbytes` to enable model loading in mixed int8. Either pass torch_dtype=torch.float16 or don't pass this argument at all to remove this warning.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "===================================BUG REPORT===================================\n",
      "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
      "================================================================================\n",
      "CUDA SETUP: CUDA runtime path found: /usr/local/cuda-11.6/lib64/libcudart.so\n",
      "CUDA SETUP: Highest compute capability among GPUs detected: 8.9\n",
      "CUDA SETUP: Detected CUDA version 116\n",
      "CUDA SETUP: Loading binary /home/mymusise/pro/stable-diffusion-webui/venv/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cuda116.so...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 8/8 [00:07<00:00,  1.04it/s]\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoConfig, AutoModel, AutoTokenizer, TrainingArguments\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from peft import LoraConfig, TaskType, get_peft_model\n",
    "\n",
    "\n",
    "class CastOutputToFloat(nn.Sequential):\n",
    "    \"\"\"Sequential wrapper that upcasts the wrapped module's output to float32.\"\"\"\n",
    "\n",
    "    def forward(self, x):\n",
    "        result = super().forward(x)\n",
    "        return result.to(torch.float32)\n",
    "\n",
    "\n",
    "# ChatGLM-6B in 8-bit; device_map=\"auto\" lets accelerate place the weights on the available devices.\n",
    "model = AutoModel.from_pretrained(\n",
    "    \"THUDM/chatglm-6b\",\n",
    "    load_in_8bit=True,\n",
    "    trust_remote_code=True,\n",
    "    device_map='auto',\n",
    ")\n",
    "\n",
    "# Gradient checkpointing: the flag is set by hand, presumably because the remote\n",
    "# ChatGLM class does not advertise checkpointing support itself.\n",
    "model.supports_gradient_checkpointing = True\n",
    "model.gradient_checkpointing_enable()\n",
    "# Let gradients flow from the input embeddings even though the base weights stay frozen.\n",
    "model.enable_input_require_grads()\n",
    "# Upcast the logits to float32 before they reach downstream computations such as the loss.\n",
    "model.lm_head = CastOutputToFloat(model.lm_head)\n",
    "# KV cache off during training to silence the warnings. Please re-enable for inference!\n",
    "model.config.use_cache = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.\n"
     ]
    }
   ],
   "source": [
    "# Tokenizer matching the model loaded above (custom code from the same hub repo).\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    \"THUDM/chatglm-6b\",\n",
    "    trust_remote_code=True,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
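    "## Test before fine-tuning\n",
    "\n",
    "Generate answers for the first five Alpaca instructions with the base model. Each generated completion (prompt included) is printed first, followed by the dataset's reference answer under `### N.Answer:`."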
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Instruction: Give three tips for staying healthy.\n",
      "Answer: I'm sorry, but I'm not sure what you're asking for. Could you please provide more context or clarify your question?\n",
      "### 1.Answer:\n",
      " 1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n",
      "2. Exercise regularly to keep your body active and strong. \n",
      "3. Get enough sleep and maintain a consistent sleep schedule. \n",
      "\n",
      "\n",
      "Instruction: What are the three primary colors?\n",
      "Answer: The three primary colors in painting are red, blue, and green. These colors are often used in combination to create more complex and vibrant colors.\n",
      "### 2.Answer:\n",
      " The three primary colors are red, blue, and yellow. \n",
      "\n",
      "\n",
      "Instruction: Describe the structure of an atom.\n",
      "Answer: The原子是构成物质的基本单位,由带正电荷的质子和不带电荷的电子组成。原子核由带正电荷的质子和不带电荷的中子组成,它们之间通过核力相互吸引。\n",
      "\n",
      "原子的化学性质取决于其组成和结构,以及化学反应中所涉及的因素。例如,氧原子是化学反应中最常见的原子之一,因为它具有两个电子,可以与许多其他原子形成化合物。\n",
      "\n",
      "物质是由原子和分子组成的,而原子是物质的基本单位。了解原子的结构、性质以及化学反应,可以帮助我们更好地理解物质世界,并更好地利用这些物质来制造产品、治疗疾病。\n",
      "### 3.Answer:\n",
      " An atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom. \n",
      "\n",
      "\n",
      "Instruction: How can we reduce air pollution?\n",
      "Answer: I'm sorry, but I'm not sure what you're asking. Could you please provide more context or clarify your question?\n",
      "### 4.Answer:\n",
      " There are a number of ways to reduce air pollution, such as shifting to renewable energy sources, encouraging the use of public transportation, prohibiting the burning of fossil fuels, implementing policies to reduce emissions from industrial sources, and implementing vehicle emissions standards. Additionally, individuals can do their part to reduce air pollution by reducing car use, avoiding burning materials such as wood, and changing to energy efficient appliances. \n",
      "\n",
      "\n",
      "Instruction: Describe a time when you had to make a difficult decision.\n",
      "Answer: I'm sorry, but I'm not sure what you're asking for. Could you please provide more context or clarify your question?\n",
      "### 5.Answer:\n",
      " I had to make a difficult decision when I was working as a project manager at a construction company. I was in charge of a project that needed to be completed by a certain date in order to meet the client’s expectations. However, due to unexpected delays, we were not able to meet the deadline and so I had to make a difficult decision. I decided to extend the deadline, but I had to stretch the team’s resources even further and increase the budget. Although it was a risky decision, I ultimately decided to go ahead with it to ensure that the project was completed on time and that the client’s expectations were met. The project was eventually successfully completed and this was seen as a testament to my leadership and decision-making abilities. \n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from cover_alpaca2jsonl import format_example\n",
    "import json\n",
    "\n",
    "\n",
    "with open(\"data/alpaca_data.json\") as fp:\n",
    "    instructions = json.load(fp)\n",
    "\n",
    "\n",
    "model.eval()  # make sure dropout is off while generating\n",
    "with torch.no_grad():\n",
    "    for idx, item in enumerate(instructions[:5]):\n",
    "        feature = format_example(item)\n",
    "        input_text = feature[\"context\"]\n",
    "        # Put the prompt on the model's device; generate() warns when input_ids stay on the CPU\n",
    "        # while the model is on CUDA.\n",
    "        input_ids = tokenizer.encode(input_text, return_tensors='pt').to(model.device)\n",
    "        out = model.generate(\n",
    "            input_ids=input_ids,\n",
    "            max_length=150,\n",
    "            do_sample=False,  # greedy decoding; temperature only applies when sampling\n",
    "        )\n",
    "        answer = tokenizer.decode(out[0])\n",
    "        print(answer)\n",
    "        item['infer_answer'] = answer\n",
    "        print(f\"### {idx+1}.Answer:\\n\", item.get('output'), '\\n\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import PeftModel\n",
    "\n",
    "peft_config = LoraConfig(\n",
    "    task_type=TaskType.CAUSAL_LM,\n",
    "    inference_mode=False,\n",
    "    r=8,\n",
    "    lora_alpha=32,\n",
    "    lora_dropout=0.1,\n",
    ")\n",
    "\n",
    "# Wrap with LoRA adapters only once: re-running this cell would otherwise nest a second PeftModel.\n",
    "if not isinstance(model, PeftModel):\n",
    "    model = get_peft_model(model, peft_config)\n",
    "# Presumably tells the HF Trainer the model is already spread over devices (device_map='auto'),\n",
    "# so it does not move or wrap it itself -- TODO confirm for the installed transformers version.\n",
    "model.is_parallelizable = True\n",
    "model.model_parallel = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "dataset_path = \"data/alpaca/\"\n",
    "train_num = 500  # number of leading examples used for this quick fine-tuning run\n",
    "\n",
    "# Dataset directory written by tokenize_dataset_rows.py above.\n",
    "dataset = datasets.load_from_disk(dataset_path)\n",
    "# Slicing yields a dict of columns; rebuild an in-memory Dataset holding just those rows.\n",
    "mini_train_dataset = datasets.Dataset.from_dict(dataset[:train_num])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Trainer, HfArgumentParser\n",
    "\n",
    "\n",
    "def data_collator(features: list) -> dict:\n",
    "    \"\"\"Pad a batch of tokenized examples and build language-modeling labels.\n",
    "\n",
    "    Each feature needs `input_ids` (list of token ids) and `seq_len` (presumably the\n",
    "    prompt length). Label positions before index `seq_len - 1` and all padding are set\n",
    "    to -100 (the usual ignore index for the loss). Examples are returned longest-first.\n",
    "    \"\"\"\n",
    "    len_ids = [len(feature[\"input_ids\"]) for feature in features]\n",
    "    longest = max(len_ids)\n",
    "    input_ids = []\n",
    "    labels_list = []\n",
    "    for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]):\n",
    "        ids = feature[\"input_ids\"]\n",
    "        seq_len = feature[\"seq_len\"]\n",
    "        pad_len = longest - ids_l\n",
    "        labels = [-100] * (seq_len - 1) + ids[(seq_len - 1) :] + [-100] * pad_len\n",
    "        ids = ids + [tokenizer.pad_token_id] * pad_len\n",
    "        input_ids.append(torch.LongTensor(ids))\n",
    "        labels_list.append(torch.LongTensor(labels))\n",
    "    return {\n",
    "        \"input_ids\": torch.stack(input_ids),\n",
    "        \"labels\": torch.stack(labels_list),\n",
    "    }\n",
    "\n",
    "\n",
    "class ModifiedTrainer(Trainer):\n",
    "    \"\"\"Trainer that feeds only `input_ids` and `labels` to the model.\"\"\"\n",
    "\n",
    "    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n",
    "        # **kwargs absorbs extra arguments newer Trainer versions pass (e.g. num_items_in_batch).\n",
    "        outputs = model(\n",
    "            input_ids=inputs[\"input_ids\"],\n",
    "            labels=inputs[\"labels\"],\n",
    "        )\n",
    "        # Trainer.prediction_step calls compute_loss(..., return_outputs=True) and unpacks\n",
    "        # a (loss, outputs) pair; returning only the loss there would break evaluation.\n",
    "        return (outputs.loss, outputs) if return_outputs else outputs.loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1500' max='1500' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1500/1500 07:54, Epoch 3/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>50</td>\n",
       "      <td>2.091900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>1.787700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>150</td>\n",
       "      <td>2.020000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>1.673600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>250</td>\n",
       "      <td>1.978900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>1.534500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>350</td>\n",
       "      <td>1.611100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>1.620300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>450</td>\n",
       "      <td>1.778000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>1.610000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>550</td>\n",
       "      <td>1.493800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>1.479100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>650</td>\n",
       "      <td>1.205300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>700</td>\n",
       "      <td>1.375400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>750</td>\n",
       "      <td>1.489700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>800</td>\n",
       "      <td>1.614800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>850</td>\n",
       "      <td>1.405900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>900</td>\n",
       "      <td>1.440500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>950</td>\n",
       "      <td>1.468200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>1.071200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1050</td>\n",
       "      <td>1.137700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1100</td>\n",
       "      <td>0.942300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1150</td>\n",
       "      <td>1.184100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1200</td>\n",
       "      <td>0.970900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1250</td>\n",
       "      <td>1.262800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1300</td>\n",
       "      <td>1.108000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1350</td>\n",
       "      <td>1.203000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1400</td>\n",
       "      <td>1.265900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1450</td>\n",
       "      <td>1.715500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>1.326100</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mymusise/pro/stable-diffusion-webui/venv/lib/python3.8/site-packages/bitsandbytes/autograd/_functions.py:298: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
      "  warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
      "/home/mymusise/pro/stable-diffusion-webui/venv/lib/python3.8/site-packages/bitsandbytes/autograd/_functions.py:298: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
      "  warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=1500, training_loss=1.4622032979329427, metrics={'train_runtime': 474.9934, 'train_samples_per_second': 3.158, 'train_steps_per_second': 3.158, 'total_flos': 3781851053211648.0, 'train_loss': 1.4622032979329427, 'epoch': 3.0})"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "training_args = TrainingArguments(\n",
    "    output_dir=\"output\",\n",
    "    per_device_train_batch_size=1,\n",
    "    gradient_accumulation_steps=1,\n",
    "    learning_rate=1e-4,\n",
    "    max_steps=1500,\n",
    "    fp16=True,\n",
    "    logging_steps=50,\n",
    "    # Keep the `seq_len` column: data_collator reads it, but it is never passed to the model.\n",
    "    remove_unused_columns=False,\n",
    "    group_by_length=False,\n",
    "    seed=0,\n",
    "    data_seed=0,\n",
    ")\n",
    "\n",
    "trainer = ModifiedTrainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=mini_train_dataset,\n",
    "    data_collator=data_collator,\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
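    "## Test after fine-tuning\n",
    "\n",
    "Run the same first five instructions through the LoRA-fine-tuned model and compare each completion with the reference answer."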
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mymusise/pro/stable-diffusion-webui/venv/lib/python3.8/site-packages/transformers-4.27.0.dev0-py3.8.egg/transformers/generation/utils.py:1374: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cpu, whereas the model is on cuda. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cuda') before running `.generate()`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Instruction: Give three tips for staying healthy.\n",
      "Answer: 1. Eat a balanced diet. \n",
      "2. Get regular exercise. \n",
      "3. Stay hydrated by drinking plenty of water.\n",
      "### 1.Answer:\n",
      " 1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n",
      "2. Exercise regularly to keep your body active and strong. \n",
      "3. Get enough sleep and maintain a consistent sleep schedule. \n",
      "\n",
      "\n",
      "Instruction: What are the three primary colors?\n",
      "Answer: The three primary colors are red, blue, and yellow.\n",
      "### 2.Answer:\n",
      " The three primary colors are red, blue, and yellow. \n",
      "\n",
      "\n",
      "Instruction: Describe the structure of an atom.\n",
      "Answer: An atom is a small particle of a chemical element, with a central electron surrounded by other electrons, which form a cloud around the central electron. The cloud of electrons is surrounded by a cloud of positive ions, which make up the原子's positive charge. The positive ions and negative ions are attracted to each other, and the atoms form a cloud of ions and electrons.\n",
      "### 3.Answer:\n",
      " An atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom. \n",
      "\n",
      "\n",
      "Instruction: How can we reduce air pollution?\n",
      "Answer: There are several ways to reduce air pollution, including reducing energy consumption, improving transportation, reducing waste and reducing the use of harmful chemicals. Additionally, increasing public awareness and education, implementing policies, and increasing funding for research can also help reduce air pollution.\n",
      "### 4.Answer:\n",
      " There are a number of ways to reduce air pollution, such as shifting to renewable energy sources, encouraging the use of public transportation, prohibiting the burning of fossil fuels, implementing policies to reduce emissions from industrial sources, and implementing vehicle emissions standards. Additionally, individuals can do their part to reduce air pollution by reducing car use, avoiding burning materials such as wood, and changing to energy efficient appliances. \n",
      "\n",
      "\n",
      "Instruction: Describe a time when you had to make a difficult decision.\n",
      "Answer: I had to make a difficult decision when I was in my 20s. I had to choose between my career and my studies. I knew that my studies were more important, but I also knew that I wanted to be a teacher. I had to make a decision based on my values and my future plans. I decided to pursue my studies and became a teacher. It was a difficult decision, but I was determined to make the right choice.\n",
      "### 5.Answer:\n",
      " I had to make a difficult decision when I was working as a project manager at a construction company. I was in charge of a project that needed to be completed by a certain date in order to meet the client’s expectations. However, due to unexpected delays, we were not able to meet the deadline and so I had to make a difficult decision. I decided to extend the deadline, but I had to stretch the team’s resources even further and increase the budget. Although it was a risky decision, I ultimately decided to go ahead with it to ensure that the project was completed on time and that the client’s expectations were met. The project was eventually successfully completed and this was seen as a testament to my leadership and decision-making abilities. \n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from cover_alpaca2jsonl import format_example\n",
    "import json\n",
    "\n",
    "\n",
    "with open(\"data/alpaca_data.json\") as fp:\n",
    "    instructions = json.load(fp)\n",
    "\n",
    "\n",
    "# Trainer leaves the model in training mode; switch to eval so LoRA dropout (0.1) is off while generating.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    for idx, item in enumerate(instructions[:5]):\n",
    "        feature = format_example(item)\n",
    "        input_text = feature[\"context\"]\n",
    "        # Put the prompt on the model's device; generate() warns when input_ids stay on the CPU\n",
    "        # while the model is on CUDA.\n",
    "        input_ids = tokenizer.encode(input_text, return_tensors='pt').to(model.device)\n",
    "        out = model.generate(\n",
    "            input_ids=input_ids,\n",
    "            max_length=150,\n",
    "            do_sample=False,  # greedy decoding; temperature only applies when sampling\n",
    "        )\n",
    "        answer = tokenizer.decode(out[0])\n",
    "        print(answer)\n",
    "        item['infer_answer'] = answer\n",
    "        print(f\"### {idx+1}.Answer:\\n\", item.get('output'), '\\n\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "\n",
    "def save_tunable_parameters(model, path):\n",
    "    \"\"\"Save every parameter with requires_grad=True to `path` via torch.save.\n",
    "\n",
    "    After get_peft_model these are presumably just the LoRA adapter weights. Tensors are\n",
    "    detached and copied to CPU, so the file carries no autograd state and loads without a\n",
    "    GPU. The parent directory of `path` is created if it does not exist.\n",
    "    \"\"\"\n",
    "    parent = os.path.dirname(path)\n",
    "    if parent:\n",
    "        os.makedirs(parent, exist_ok=True)\n",
    "    saved_params = {\n",
    "        name: param.detach().cpu()\n",
    "        for name, param in model.named_parameters()\n",
    "        if param.requires_grad\n",
    "    }\n",
    "    torch.save(saved_params, path)\n",
    "\n",
    "\n",
    "save_tunable_parameters(model, os.path.join(\"output\", \"chatglm-lora.pt\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12 (default, Feb  7 2022, 13:32:35) \n[GCC 9.4.0]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "25273a2a68c96ebac13d7fb9e0db516f9be0772777a0507fe06d682a441a3ba7"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
